import { triplitClient } from "@main/triplit/client";
import logger from "@shared/logger/main-logger";
import type {
  CreateThreadData,
  Message,
  Thread,
  UpdateThreadData,
} from "@shared/triplit/types";
import { injectable } from "inversify";
import { BaseDbService } from "./base-db-service";

/**
 * Data-access service for the `threads` collection, backed by the shared
 * Triplit client.
 *
 * Registered with Inversify via `@injectable()`. The collection name is
 * forwarded to `BaseDbService`; what the base class does with it is not
 * visible from this file.
 */
@injectable()
export class ThreadDbService extends BaseDbService {
  constructor() {
    super("threads");
  }

  /**
   * Inserts a new thread.
   *
   * @param thread - Data for the thread to create.
   * @returns The value returned by the Triplit `insert` call.
   */
  async insertThread(thread: CreateThreadData): Promise<Thread> {
    return await triplitClient.insert("threads", thread);
  }

  /**
   * Deletes the thread with the given id.
   *
   * @param threadId - Id of the thread to delete.
   */
  async deleteThread(threadId: string) {
    await triplitClient.delete("threads", threadId);
  }

  /**
   * Applies a partial update to a thread and/or refreshes its `updatedAt`.
   *
   * @param threadId - Id of the thread to update.
   * @param updateData - Fields to copy onto the thread. `updatedAt` is excluded
   *   from the type; use `shouldUpdateTimestamp` to touch it instead.
   * @param shouldUpdateTimestamp - When `true`, sets `updatedAt` to the current
   *   time. Defaults to `false`.
   * @throws Rethrows any error raised by the update after logging it.
   */
  async updateThread(
    threadId: string,
    updateData?: Omit<UpdateThreadData, "updatedAt">,
    shouldUpdateTimestamp: boolean = false,
  ) {
    try {
      await triplitClient.update("threads", threadId, async (thread) => {
        if (updateData) {
          Object.assign(thread, updateData);
        }

        if (shouldUpdateTimestamp) {
          thread.updatedAt = new Date();
        }
      });
    } catch (error) {
      logger.error("ThreadDbService:updateThread error", { error, threadId });
      throw error;
    }
  }

  /**
   * Fetches a single thread by id.
   *
   * @param threadId - Id of the thread to fetch.
   * @returns The result of the Triplit `fetchById` call (typed as the thread
   *   or `null`).
   */
  async getThreadById(threadId: string): Promise<Thread | null> {
    return await triplitClient.fetchById("threads", threadId);
  }

  /**
   * Fetches every thread in the collection, without any filter.
   *
   * @returns All threads returned by an unfiltered query.
   */
  async getThreads(): Promise<Thread[]> {
    return await triplitClient.fetch(triplitClient.query("threads"));
  }

  /**
   * Deletes every thread whose `collected` field is `false`.
   *
   * Despite the name, threads with `collected` set to `true` are left
   * untouched. All deletions run inside a single transaction.
   *
   * @returns The ids of the threads that were selected for deletion.
   */
  async deleteAllThreads(): Promise<string[]> {
    const uncollectedQuery = triplitClient
      .query("threads")
      .Where("collected", "=", false);
    const threads = await triplitClient.fetch(uncollectedQuery);

    // Nothing matched: avoid opening a transaction that would do no work.
    if (threads.length === 0) {
      return [];
    }

    await triplitClient.transact(async (tx) => {
      await Promise.all(
        threads.map((thread) => tx.delete("threads", thread.id)),
      );
    });

    return threads.map((thread) => thread.id);
  }

  /**
   * Loads the messages attached to a thread through its `messages` relation.
   *
   * This is a best-effort read: failures are logged and reported as an empty
   * list rather than thrown.
   *
   * @param threadId - Id of the thread whose messages are requested.
   * @returns The thread's messages, or `[]` when no thread matches or the
   *   query fails.
   */
  async getMessagesByThreadId(threadId: string): Promise<Message[]> {
    try {
      const threadsQuery = triplitClient
        .query("threads")
        .Where("id", "=", threadId)
        .Include("messages");
      const thread = await triplitClient.fetchOne(threadsQuery);
      return thread?.messages ?? [];
    } catch (error) {
      logger.error("ThreadDbService:getMessagesByThreadId error", {
        error,
        threadId,
      });
      return [];
    }
  }

  /**
   * Collects the inputs needed to summarize a thread into a title: the name
   * of its related model, its provider id, and its messages.
   *
   * @param threadId - Id of the thread to read.
   * @returns The model name, provider id and messages of the thread.
   * @throws Error with message "Thread not found" when no thread matches, when
   *   the thread has no related model, or when it has no messages. Errors from
   *   the query itself are rethrown. Every failure is logged exactly once.
   */
  async getTitleSummaryParams(threadId: string): Promise<{
    modelName: string;
    providerId: string;
    messages: Message[];
  }> {
    try {
      const threadsQuery = triplitClient
        .query("threads")
        .Where("id", "=", threadId)
        .Select(["providerId"])
        .Include("messages")
        .Include("model");
      const thread = await triplitClient.fetchOne(threadsQuery);

      if (!thread || !thread.model || thread.messages.length === 0) {
        // Not logged here: the catch below logs every failure once, so logging
        // in this branch as well would duplicate the entry.
        throw new Error("Thread not found");
      }

      return {
        modelName: thread.model.name,
        providerId: thread.providerId,
        messages: thread.messages,
      };
    } catch (error) {
      logger.error("ThreadDbService:getTitleSummaryParams error", {
        error,
        threadId,
      });
      throw error;
    }
  }
}
